Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ralberd/path dependent adjoint #69

Merged
merged 22 commits into from
Jan 17, 2024
Merged

Conversation

ralberd
Copy link
Contributor

@ralberd ralberd commented Oct 27, 2023

Initial attempts at implementing the path-dependent adjoint sentitivity for the full FEA using J2 plasticity. Creating a draft PR so that I can get feedback as I flesh this out.

  • Unit tests checking that the jax-computed gradients of internal variable updates match analytic values for small strain linear hardening
  • "Unit test" with implementation of full path-dependent adjoint solve - Right now this doesn't test anything. I need to figure out how to check the gradient (finite differences?)
  • I plan on adding more granular unit tests checking the individual parts of the adjoint equations
  • I plan on figuring out a more efficient way of computing derivatives of internal variables as they are only dependent on quantities in their own element / integration point (computing for every other element / integration point is wasteful)

@ralberd ralberd self-assigned this Oct 27, 2023
…erelastic objectives; adding FD check for J2 objective (is failing)
@ralberd
Copy link
Contributor Author

ralberd commented Oct 31, 2023

There appears to be a bug in my adjoint system implementation for J2 plasticity. Feel free to take a look and see if anything sticks out. I will keep working on it this week.

@ralberd ralberd marked this pull request as ready for review November 10, 2023 00:42
@ralberd ralberd changed the title Draft: ralberd/path dependent adjoint ralberd/path dependent adjoint Nov 10, 2023
@ralberd
Copy link
Contributor Author

ralberd commented Nov 10, 2023

This is ready for review now. Let me know what you think!

@@ -0,0 +1,67 @@
from optimism.test.MeshFixture import *
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask that you import the module directly and avoid import * in general? This pulls everythin into the top-level namespace makes it harder to see where things come from and to avoid name clashes. I realize that this idiom is used throughout optimism, but I've been trying to remove them bit by bit.

@@ -0,0 +1,103 @@
from collections import namedtuple

from optimism.JaxConfig import *
Copy link
Collaborator

@btalamini btalamini Jan 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need this import. It's a legacy thing I'm trying to remove. If it doesn't work without it, show me the error and I'll fix it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I happen to like import * from JaxConfig. For me, its intent was to avoid having to constantly duplicate lots of lines like: from jax import grad, import blah, ... which I found to be annoying in every file. It's an example where the DRY (dont repeat yourself) principal is in conflict with other principals. For me, I like DRY here, but I know there is a trade off. I understand that it pull in extra things, but that is essentially the point! Without it, the tops of files are more verbose and it takes longer to get to the actual work of the file without such things. If JaxConfig pull in things that may be duplicated or we don't want, we should delete them from JaxConfig.

Copy link
Collaborator

@tupek2 tupek2 Jan 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Plus, the way I think about it at least, I don't want to type jax.grad, because that is essentially hard-coding the implemenation to jax (which we are hard coded to anyway, but whatever). Its like in c++, where you can say using MatrixClass = Eigen::Matrix, but, potentially, if one is very careful, it is possible to switch out the matrix by just changing MatrixClass = Armadillo::Matrix. So, I don't think of it as polluting the namespace, I think of it as concretizing the specific grad our library is using.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand what you mean, but there are better ways of doing that. The wildcard import is discouraged pretty much universally because it makes it so hard to understand where things are coming from, and names can collide and unintentionally overwrite the one you want. It's essentially like a GOTO - you'd have to follow all the import * instances to see which grad you're ultimately getting. Even the Python project itself doesn't allow it in their style any more. If you want to use the concretization/substitution idea, there are ways to do it by importing things in the __init__.py file for the project, which is similar, but forces you to be explicit (like your MatrixClass = Armadillo::Matrix example).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, can we basically put things like grad = jax.grad, etc. in init then? Basically things which are essentially always used and in the way we generally want to use them, so we can establish a standard way of doing it. That was the intent of the .Config file. In C++, the equivalent is basically #include "types.hpp", where types are put into the same namespace as the project on purpose. That is essentially the same thing I want to do in python. The issues is that putting it in init puts in in EVERY file, where as "from optimism.types import *" is more selective. So I don't really appreciate what the issue with that is. We are not doing "from jax import *", we are including a types header.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should discuss this offline. Ryan, you can leave things as they are until we come to an agreement.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, go ahead and leave the wildcard import. We'll make a sweeping change in the future to get rid of them. Try to leave out the JaxConfig import in the future.

ResidualInverseFunctions = namedtuple('ResidualInverseFunctions',
['residual_jac_coords_vjp'])

def _compute_quadrature_point_field_gradient(u, shapeGrad):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function appears to be copied and pasted from FunctionSpace - why not use that directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, fixed it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The permissions on this file are wrong: it's currently marked as executable.

@@ -0,0 +1,67 @@
from optimism.test.MeshFixture import MeshFixture
from collections import namedtuple
import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, can you import numpy as onp here? We try to be consistent about calling the jax numpy np and the original numpy as onp. It makes it easier to understand a file at a glance.

@@ -69,7 +69,7 @@ def get_lobatto_nodes_1d(degree):
p = onp.polynomial.Legendre.basis(degree, domain=[0.0, 1.0])
dp = p.deriv()
xInterior = dp.roots()
xn = np.hstack(([0.0], xInterior, [1.0]))
xn = np.hstack((np.array([0.0]), xInterior, np.array([1.0])))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, those warnings were annoying.

@@ -0,0 +1,103 @@
from collections import namedtuple

from optimism.JaxConfig import *
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, go ahead and leave the wildcard import. We'll make a sweeping change in the future to get rid of them. Try to leave out the JaxConfig import in the future.

@ralberd ralberd merged commit a2f8efe into main Jan 17, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants